import sys
sys.path.append("../SEV/")
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pacmap
from FCMCluster import FuzzyCMeans
from Encoder import DataEncoder
from data_loader import data_loader
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import GradientBoostingClassifier

# Per-dataset settings, one row per dataset:
#   (key, display name, n_neighbors entry, n_class entry, ms entry)
# The parallel lists below are derived from this table, so they stay aligned
# by position with `datasets`.
# NOTE(review): `ms` is presumably a per-dataset fuzziness exponent for the
# clustering step - confirm against its consumers.
_DATASET_TABLE = [
    ("adult", "Adult", None, 7, 1.01),
    ("german", "German Credit", 20, 3, 1.01),
    ("compas", "COMPAS", 20, 5, 2),
    ("diabetes", "Diabetes", 20, 4, 1.01),
    ("fico", "FICO", None, 4, 3),
    ("mimic", "MIMIC III", None, 4, 1.01),
    ("headline1", "Headline1", None, 3, 1.01),
    ("headline2", "Headline2", None, 2, 1.01),
    ("headline3", "Headline3", None, 3, 1.01),
    ("headline_total", "Headline News", None, 2, 1.01),
]

datasets = [row[0] for row in _DATASET_TABLE]
n_neighbors = [row[2] for row in _DATASET_TABLE]
n_class = [row[3] for row in _DATASET_TABLE]
ms = [row[4] for row in _DATASET_TABLE]

# Subset of `datasets` that is actually processed and plotted.
selected_datasets = [
    "adult",
    "german",
    "compas",
    "diabetes",
    "mimic",
    "headline_total",
]

# Dataset key -> human-readable name used in the panel titles.
dataset_name_map = {row[0]: row[1] for row in _DATASET_TABLE}

# Build one figure with a panel per selected dataset (three panels per row).
# For each dataset: embed the encoded data with PaCMAP, fit a gradient-boosting
# classifier, cluster the negative-sample embedding with FuzzyCMeans, and mark
# the embedded median of every cluster with a star.
n_cols = 3
n_rows = max(1, -(-len(selected_datasets) // n_cols))  # ceiling division
fig, ax = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)

# Panels beyond the number of selected datasets stay blank.
for spare_panel in ax.flat[len(selected_datasets):]:
    spare_panel.axis("off")

# Looked up once; the colormap does not depend on the dataset.
cmap = plt.get_cmap("Pastel1")

for ind, dataset in enumerate(selected_datasets):
    # `ind` is only the panel position.  The per-dataset settings
    # (`n_neighbors`, `n_class`) are aligned with `datasets`, so they must be
    # looked up through the dataset's position in that list, not through `ind`.
    cfg = datasets.index(dataset)
    n_clusters = n_class[cfg]

    X, y, X_neg = data_loader(dataset)
    encoder = DataEncoder(standard=True)
    encoder.fit(X_neg)
    encoded_X = encoder.transform(X)
    encoded_X_neg = encoder.transform(X_neg)
    print("Working on the dataset {}".format(dataset))

    # do the embedding: PaCMAP is fitted on encoded_X_neg and then used to
    # project the full encoded data set
    pacmapper = pacmap.PaCMAP(n_components=2, n_neighbors=n_neighbors[cfg], MN_ratio=1, FP_ratio=2.0)
    pacmapper.fit(encoded_X_neg, init="pca")
    X_embedded = pacmapper.transform(encoded_X, encoded_X_neg)
    # NOTE(review): the code below pairs X_embedded[y == 0] row-by-row with
    # encoded_X_neg, i.e. it assumes X_neg holds exactly the y == 0 rows of X
    # in the same order - confirm against data_loader.
    X_embedded_neg = X_embedded[y == 0]
    X_embedded_pos = X_embedded[y == 1]

    model = GradientBoostingClassifier(n_estimators=200, max_depth=3, random_state=42)
    model.fit(encoded_X, y)

    # do the clustering
    # NOTE(review): m is fixed to 3 here while the module-level `ms` list is
    # never used - confirm which value is intended.
    fcm = FuzzyCMeans(model, n_clusters=n_clusters, m=3)
    fcm.fit(X_embedded_neg, encoded_X_neg)
    labels = fcm.predict(X_embedded_neg, encoded_X_neg)
    predicted_labels = model.predict(encoded_X_neg)

    print("The count of each label is")
    print(pd.Series(labels).value_counts())

    X_med_embedded = []

    # Per cluster: take the feature-wise median of the members the model
    # predicts as negative, embed it, and report it if the model predicts the
    # median itself as positive.
    for i in range(n_clusters):
        members = (labels == i) & (predicted_labels == 0)
        if not np.any(members):
            # The median of an empty selection is NaN and cannot be fed to the
            # model or the embedding, so there is no marker for this cluster.
            print("The cluster {} has no sample predicted as negative; skipping its median".format(i))
            continue
        X_med = np.median(encoded_X_neg[members], axis=0)
        X_med_row = X_med.reshape(1, -1)
        X_med_embedded.append(pacmapper.transform(X_med_row, encoded_X_neg))
        if model.predict(X_med_row)[0] != 0:
            print("The median of the cluster {} is predicted as positive".format(i))
            print(model.predict_proba(X_med_row))
            print("The median is {}".format(X_med))

    X_med_embedded = np.array(X_med_embedded).reshape(-1, 2)

    # plot the results
    panel = ax.flat[ind]
    colors = cmap(np.linspace(0, 0.7, n_clusters))
    panel.scatter(X_embedded_pos[:, 0], X_embedded_pos[:, 1], c="gray", s=10, alpha=0.1)
    for i in range(n_clusters):
        in_cluster = labels == i
        # `color=` (not `c=`): a bare RGBA vector passed as `c` is ambiguous
        # and would be value-mapped if the cluster had exactly four points.
        panel.scatter(X_embedded_neg[in_cluster, 0], X_embedded_neg[in_cluster, 1], color=colors[i], s=10, alpha=0.7)
    panel.scatter(X_med_embedded[:, 0], X_med_embedded[:, 1], marker="*", s=100, c="blue")
    # remove the x,y axis ticks
    panel.set_xticks([])
    panel.set_yticks([])
    panel.set_title("{}:{} Classes".format(dataset_name_map[dataset], n_clusters))

fig.savefig("../Results/figures/Clustering_Results_new.png")


